import torch
import os
import random
import json
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor
import torch
import numpy as np
from tqdm import tqdm
from datasets import load_dataset, concatenate_datasets

# from transformers import AutoModelForCausalLM 
from modeling_phi3_v import Phi3VForCausalLM, Phi3Attention
from transformers import AutoProcessor 

from utils.data_utils import load_yaml, construct_prompt, save_json, process_single_sample, CAT_SHORT2LONG
from utils.model_utils import phi3_image_processor, call_phi3_engine_df
from utils.eval_utils import parse_multi_choice_response, parse_open_response
from argparse import ArgumentParser

# Maps a model-variant key to the attention module class whose instances get
# `_clean_cache()` called after every generation (see
# generate_caption_for_image). A value of None marks a variant with no such
# module. NOTE(review): "phi3_h2o" presumably refers to an H2O KV-cache
# variant implemented in modeling_phi3_v.Phi3Attention — confirm there.
TAGET_MODULE = dict(
    phi3=None,
    phi3_h2o=Phi3Attention,
)

# Load the model and its processor once, at import time.
model_id = "microsoft/Phi-3.5-vision-instruct"
model = Phi3VForCausalLM.from_pretrained(
    model_id,
    torch_dtype="auto",
    device_map="cuda",
    _attn_implementation='eager',
    trust_remote_code=True,
)
processor = AutoProcessor.from_pretrained(
    model_id,
    num_crops=4,
    trust_remote_code=True,
)

# Generate an image description (one image per request).
def generate_caption_for_image(image):
    """
    Generate a Rabbids-style story description for a single image.

    Args:
        image: The image to describe. The caller in this file passes a PIL
            image produced by ``Image.open(...).resize(...)``.

    Returns:
        The list of strings returned by ``processor.batch_decode`` for the
        newly generated tokens only (prompt tokens are stripped). One prompt
        is built here, so presumably a single entry — confirm.
    """
    # Placeholder tag for the single image passed to the processor below
    # (Phi-3.5-vision prompt convention).
    placeholder = "<|image_1|>\n"

    query = (
        "Please generate a story description for the uploaded image. The requirements are:\n"
        "- The description for the image should be consistent with the style of the 'Rabbids' cartoon, ensuring that the text and visuals are aligned in terms of style.\n"
        "- The description should be engaging and entertaining, with elements that captivate the audience and maintain their interest in the story.\n"
        "- The description should be closely related to the content of the image, ensuring that the text is coherent with the visuals and maintains logical consistency in the narrative.\n"
    )

    messages = [{"role": "user", "content": placeholder + query}]
    prompt = processor.tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    # Send inputs to wherever the model actually lives instead of a
    # hard-coded device string that could disagree with device_map.
    inputs = processor(prompt, [image], return_tensors="pt").to(model.device)

    generation_args = {
        "max_new_tokens": 196,  # upper bound on the generated text length
        "temperature": 0.7,
        "do_sample": True,
    }

    cache_cls = TAGET_MODULE["phi3_h2o"]
    try:
        generate_ids = model.generate(
            **inputs,
            eos_token_id=processor.tokenizer.eos_token_id,
            **generation_args,
        )
    finally:
        # Clear per-module cache state even when generation fails, so the next
        # call does not start from stale state. Skipped when the variant maps
        # to None, because isinstance(x, None) raises TypeError.
        if cache_cls is not None:
            for module in model.modules():
                if isinstance(module, cache_cls):
                    module._clean_cache()

    # Drop the prompt tokens; keep only what the model generated.
    generate_ids = generate_ids[:, inputs["input_ids"].shape[1]:]

    return processor.batch_decode(
        generate_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )

# Read the dataset.
def load_data(file_path):
    """
    Load a JSON Lines (``.jsonl``) file into a list of records.

    Each non-blank line is parsed with ``json.loads``. Lines that are empty or
    whitespace-only (e.g. an extra empty line at the end of the file) are
    skipped instead of raising ``json.JSONDecodeError``.

    Args:
        file_path: Path to the JSONL file.

    Returns:
        The parsed records, in file order.
    """
    # Explicit UTF-8 so decoding does not depend on the platform's locale.
    with open(file_path, 'r', encoding='utf-8') as file:
        return [json.loads(line) for line in file if line.strip()]

# Process the images in the dataset and generate long-form descriptions
# (one image per request).
def process_dataset(data, image_root='/hy-tmp/Rabbids/rabbids', image_size=(256, 256)):
    """
    Caption every image referenced by the dataset, one image per model call.

    Args:
        data: Iterable of records, each with an ``'images'`` list of image
            paths. Paths are joined onto ``image_root`` with ``os.path.join``
            (an absolute path therefore replaces the root).
        image_root: Directory that image paths are resolved against. Defaults
            to the location previously hard-coded here.
        image_size: ``(width, height)`` every image is resized to before
            captioning. Defaults to the previously hard-coded 256x256.

    Returns:
        A list with one ``{"generated_caption": ...}`` dict per image, in
        dataset order.
    """
    all_results = []

    for item in tqdm(data, desc="Processing dataset", unit="item"):
        image_paths = [os.path.join(image_root, rel_path) for rel_path in item['images']]

        for image_path in tqdm(image_paths, desc="Processing images", unit="image", leave=False):
            # Use a context manager so the underlying file handle is released
            # promptly; resize() returns a new image independent of the file.
            with Image.open(image_path) as raw_image:
                image = raw_image.resize(image_size)
            generated_caption = generate_caption_for_image(image)
            all_results.append({"generated_caption": generated_caption})

    return all_results

# Program entry point.
def main():
    """
    Caption the images of the validation split and write the results to a
    JSONL file next to the input file, then print the output path.
    """
    # Input dataset; replace with your actual path.
    file_path = '/hy-tmp/Rabbids/val.jsonl'
    results = process_dataset(load_data(file_path))

    # One JSON object per line, written alongside the input file.
    output_file = os.path.join(os.path.dirname(file_path), "generated_captions_EBM.jsonl")
    with open(output_file, 'w') as file:
        file.writelines(json.dumps(record) + "\n" for record in results)

    print(f"Results saved to: {output_file}")


# Run main() only when executed as a script, not on import.
if __name__ == "__main__":
    main()
